#include <stdio.h>
#include <math.h>
#include "debug_utils.h"

void debug_print_norm(const char* classname, const char* TAG, float* array, int size, int rank, int step) {
    if (rank == 0) {
        double sum = 0.0;
        for (int i = 0; i < size; i++) {
            sum += array[i] * array[i];
        }
        double norm = sqrt(sum);
        printf("[%s] %s step=%d norm=%.5E\n", classname, TAG, step, norm);
        fflush(stdout);
    }
} 

void debug_print_norm_d(const char* classname, const char* TAG, double* array, int size, int rank, int step) {
    if (rank == 0) {
        double sum = 0.0;
        for (int i = 0; i < size; i++) {
            sum += array[i] * array[i];
        }
        double norm = sqrt(sum);
        printf("[%s] %s step=%d norm=%.12E\n", classname, TAG, step, norm);
        fflush(stdout);
    }
} 